import torch
import torch.nn as nn
import torch.nn.functional as F


class PostProcessor(nn.Module):
    """Turn a raw detection-head feature map into activated predictions.

    The module holds no learnable parameters; it only reshapes the input
    and applies element-wise / per-row activations.

    Args:
        num_classes: Number of class scores stored per anchor.
        num_anchors: Number of anchors predicted at each spatial location.
    """

    def __init__(self, num_classes=20, num_anchors=5):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors

    def forward(self, out):
        """Reshape and activate the raw head output.

        Args:
            out: Tensor of shape ``(B, num_anchors * (5 + num_classes), H, W)``.
                Each group of ``5 + num_classes`` channels is read as
                ``(t_x, t_y, t_h, t_w, t_c, class1_score, class2_score, ...)``.

        Returns:
            A dict with the following entries, where ``N = H * W * num_anchors``:

            * ``"delta_pred"``: ``(B, N, 4)`` -- sigmoid of the first two
              attributes followed by exp of the next two.
            * ``"conf_pred"``: ``(B, N, 1)`` -- sigmoid of the fifth attribute.
            * ``"class_pred"``: ``(B, N, num_classes)`` -- softmax of the
              class scores over the last dimension.
            * ``"class_score"``: ``(B, N, num_classes)`` -- class scores
              without any activation.
            * ``"height"`` / ``"width"``: the input's ``H`` and ``W``.
        """
        batch, _, height, width = out.shape
        attrs_per_anchor = 5 + self.num_classes
        rows = height * width * self.num_anchors

        # Move channels last so every anchor's attributes are adjacent in
        # memory, then flatten (H, W, anchor) into a single row axis.
        channels_last = out.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(batch, rows, attrs_per_anchor)

        # Raw class scores are returned as-is next to their softmax.
        class_score = flat[..., 5:]
        class_pred = torch.softmax(class_score, dim=-1)

        # Confidence lives in attribute slot 4 and is squashed to (0, 1).
        conf_pred = flat[..., 4:5].sigmoid()

        # Slots 0-1 go through sigmoid, slots 2-3 through exp; the results
        # are concatenated back in their original order.
        offsets = flat[..., 0:2].sigmoid()
        scales = flat[..., 2:4].exp()
        delta_pred = torch.cat((offsets, scales), dim=-1)

        result = {}
        result["delta_pred"] = delta_pred
        result["conf_pred"] = conf_pred
        result["class_pred"] = class_pred
        result["class_score"] = class_score
        result["height"] = height
        result["width"] = width
        return result
